{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d001ba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys  \n",
    "import json\n",
    "import torch \n",
    "import argparse\n",
    "import numpy as np\n",
    "from PIL import Image  \n",
    "from tqdm import tqdm\n",
    "from utils import model_gen\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer  \n",
    "\n",
    "# Checkpoint identifier passed to both from_pretrained calls below.\n",
    "ckpt_path = 'internlm/internlm-xcomposer2-vl-7b'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)\n",
    "\n",
    "# Load the model first, then switch it to eval mode, move it to CUDA and cast to half precision.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    ckpt_path,\n",
    "    device_map=\"cuda\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "model = model.eval().cuda().half()\n",
    "\n",
    "# Attach the tokenizer to the model object as an attribute.\n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ef25272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model on every image-backed sample and report aAcc, the mean\n",
    "# per-sample correctness over all evaluated samples.\n",
    "image_root = 'data/hallu/hallusion_bench/'\n",
    "\n",
    "# Use a context manager so the JSON file handle is closed after loading.\n",
    "with open('data/hallu/HallusionBench.json') as f:\n",
    "    samples = json.load(f)\n",
    "\n",
    "results = []\n",
    "for q in tqdm(samples):\n",
    "    if q['filename'] is None:\n",
    "        continue  # only evaluate the image part\n",
    "\n",
    "    # Work around file-name problems: if the image cannot be opened under the\n",
    "    # listed name, retry with the case of 'png' / 'PNG' swapped.\n",
    "    im_path = image_root + q['filename']\n",
    "    try:\n",
    "        # Open only to probe that the file is readable; close it right away.\n",
    "        with Image.open(im_path):\n",
    "            pass\n",
    "    except OSError:  # missing or unreadable file; do not swallow unrelated errors\n",
    "        if 'png' in q['filename']:\n",
    "            im_path = image_root + q['filename'].replace('png', 'PNG')\n",
    "        else:\n",
    "            im_path = image_root + q['filename'].replace('PNG', 'png')\n",
    "\n",
    "    text = '[UNUSED_TOKEN_146]user\\n{}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n'.format(q['question'])\n",
    "    with torch.cuda.amp.autocast():\n",
    "        response = model_gen(model, text, im_path, padding=True)\n",
    "\n",
    "    # Any response containing 'yes' (case-insensitive) is scored as '1', otherwise '0'.\n",
    "    ans = '1' if 'yes' in response.lower() else '0'\n",
    "    results.append(ans == q['gt_answer'])\n",
    "    # Store the prediction on the sample dict; the qAcc / fAcc cells read r['answer'].\n",
    "    # (q is a dict, so the previous q.append(ans) raised AttributeError.)\n",
    "    q['answer'] = ans\n",
    "\n",
    "print(\"aAcc:\", np.mean(results))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e975f68e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# qAcc: group samples by (category, subcategory, set_id, question_id); a group\n",
    "# counts as correct only if every sample in it was answered correctly.\n",
    "qlist = {}\n",
    "for r in samples:\n",
    "    # Samples without a stored prediction (e.g. those the evaluation loop\n",
    "    # skipped because they have no image) cannot be scored; indexing\n",
    "    # r['answer'] on them would raise KeyError.\n",
    "    if 'answer' not in r:\n",
    "        continue\n",
    "    key = \"_\".join([r[\"category\"], r[\"subcategory\"], str(r[\"set_id\"]), str(r[\"question_id\"])])\n",
    "    qlist.setdefault(key, []).append(r['answer'] == r['gt_answer'])\n",
    "\n",
    "# all(v) is True only when every comparison in the group is True.\n",
    "out = [all(v) for v in qlist.values()]\n",
    "print(\"qAcc:\", np.mean(out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08f539d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fAcc: group samples by (category, subcategory, set_id, figure_id); a group\n",
    "# counts as correct only if every sample in it was answered correctly.\n",
    "qlist = {}\n",
    "for r in samples:\n",
    "    # Samples without a stored prediction (e.g. those the evaluation loop\n",
    "    # skipped because they have no image) cannot be scored; indexing\n",
    "    # r['answer'] on them would raise KeyError.\n",
    "    if 'answer' not in r:\n",
    "        continue\n",
    "    key = \"_\".join([r[\"category\"], r[\"subcategory\"], str(r[\"set_id\"]), str(r[\"figure_id\"])])\n",
    "    qlist.setdefault(key, []).append(r['answer'] == r['gt_answer'])\n",
    "\n",
    "# all(v) is True only when every comparison in the group is True.\n",
    "out = [all(v) for v in qlist.values()]\n",
    "print(\"fAcc:\", np.mean(out))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
